#include <os/debug.h>
#include <os/diskman.h>
#include <os/initcall.h>
#include <os/driver.h>
#include <os/hardirq.h>
#include <lib/stdlib.h>
#include <lib/string.h>
#include <sys/ioctl.h>
#include <driver/cdrom.h>

// Set by AtapiHandler from interrupt context and polled by IdeWaitIrq.
// volatile so every poll re-reads the flag instead of reusing a cached value.
static volatile uint8_t irq_invoked = 0;

static int IdePrintError(device_extension_t *extension, uint32_t err);
static int IdeReadSector(device_extension_t *extension, uint32_t lba, void *buff, uint32_t count);
static void ResetIrq(void);
static void IdeWaitIrq(void);
static void AtapiHandler(irqno_t irq, void *data);
static void SendCmd(ide_channel_t *channel, uint32_t cmd);
static int ATAPIDeviceTransfer(device_extension_t *extension, uint8_t op, uint32_t lba, void *buff, uint32_t count);

/*
 * Write a command byte to the command register of an IDE channel.
 * The value is handed to Out8, so only its low byte reaches the port.
 */
static void SendCmd(ide_channel_t *channel, uint32_t cmd)
{
    Out8(ATA_REG_CMD + channel->iobase, cmd);
}

/*
 * Read `count` sectors starting at `lba` into `buff`, one ATAPI sector per transfer.
 * Only ATAPI devices are handled; for any other device type nothing is read
 * and 0 is returned.
 * Returns 0 on success, -1 if the range is outside the device or a transfer fails.
 */
static int IdeReadSector(device_extension_t *extension, uint32_t lba, void *buff, uint32_t count)
{
    uint8_t *dst = (uint8_t *)buff;
    uint32_t i;
    int err;

    if (extension->type != ATAPI_DEVICE)
        return 0;

    // reject ranges past the last sector; written so that lba + count cannot wrap
    if (count > extension->size || lba > extension->size - count)
    {
        KPrint(PRINT_ERR "%s: ide read err!\n", __func__);
        return -1;
    }

    // read sector from ide
    for (i = 0; i < count; i++)
    {
        // each ATAPIDeviceTransfer call fills one ATAPI_SECTOR_SIZE-byte sector,
        // so the destination must advance by that stride (not SECTOR_SIZE)
        err = ATAPIDeviceTransfer(extension, IDE_READ, lba + i, dst + i * ATAPI_SECTOR_SIZE, 1);
        if (IdePrintError(extension, (uint32_t)err))
        {
            KPrint(PRINT_ERR "%s ide read err!\n", __func__);
            return -1;
        }
    }
    return 0;
}

/*
 * Read `count` ATAPI sectors starting at `lba` into `buff` using a PIO PACKET command.
 * `op` is not used; only reads are implemented.
 * Returns 0 on success. On failure it returns the codes IdePolling uses, so
 * IdePrintError can decode them: 1 = device fault, 2 = error bit set,
 * 3 = drive did not request data.
 * NOTE(review): the packet layout assumes ATAPI_CMD_READ is SCSI READ(12) (0xA8),
 * whose transfer length occupies bytes 6-9 -- confirm in driver/cdrom.h.
 */
static int ATAPIDeviceTransfer(device_extension_t *extension, uint8_t op, uint32_t lba, void *buff, uint32_t count)
{
    ide_channel_t *channel = extension->channel;
    uint8_t packet[12] = {ATAPI_CMD_READ, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
    uint16_t *half_buff = (uint16_t *)buff;
    // number of 16-bit words the caller's buffer can hold
    uint32_t capacity = count * (ATAPI_SECTOR_SIZE / 2);
    uint32_t size, words, i;
    uint8_t status;

    (void)op;

    // select disk, then allow ~400ns for the selection to settle
    Out8(channel->iobase + ATA_REG_HDDSEL, ATA_DEVICE_MASK(0, extension->driver, 0));
    for (i = 0; i < 4; i++)
        In8(channel->iobase + ATA_REG_ALTSTU);
    Out8(channel->iobase + ATA_REG_FEATURE, 0); // use PIO mode
    // byte count limit per DRQ block = one ATAPI sector
    Out8(channel->iobase + ATA_REG_LBA1, ATAPI_SECTOR_SIZE & 0xff);
    Out8(channel->iobase + ATA_REG_LBA2, ATAPI_SECTOR_SIZE >> 8);

    // send "PACKET" command to enable ATAPI packet transfer
    Out8(channel->iobase + ATA_REG_CMD, ATA_CMD_PACKET);
    for (i = 0; i < 4; i++)
        In8(channel->iobase + ATA_REG_ALTSTU);

    // wait for BSY clear and DRQ set; bail out if the drive rejected the command
    for (;;)
    {
        status = In8(channel->iobase + ATA_REG_STATUS);
        if (!(status & ATA_STATUS_BUSY))
        {
            if (status & ATA_STATUS_ERR)
                return 2;
            if (status & ATA_STATUS_DRQ)
                break;
        }
        CpuIdle();
    }

    // command packet: big-endian LBA in bytes 2-5, sector count in bytes 6-9
    packet[2] = (lba >> 24) & 0xff;
    packet[3] = (lba >> 16) & 0xff;
    packet[4] = (lba >> 8) & 0xff;
    packet[5] = lba & 0xff;
    packet[6] = (count >> 24) & 0xff;
    packet[7] = (count >> 16) & 0xff;
    packet[8] = (count >> 8) & 0xff;
    packet[9] = count & 0xff;

    // clear the flag right before the packet so only the data-ready irq is awaited
    ResetIrq();
    // 12 packet bytes go out as 6 little-endian words (bytes 2i and 2i+1)
    for (i = 0; i < 6; i++)
    {
        Out16(channel->iobase + ATA_REG_DATA, (uint16_t)(packet[2 * i] | (packet[2 * i + 1] << 8)));
    }

    // wait irq, then let status settle and BSY clear before inspecting it
    IdeWaitIrq();
    for (i = 0; i < 4; i++)
        In8(channel->iobase + ATA_REG_ALTSTU);
    while ((status = In8(channel->iobase + ATA_REG_STATUS)) & ATA_STATUS_BUSY)
        CpuIdle();
    if (status & ATA_STATUS_ERR)
        return 2;
    if (status & ATA_STATUS_DEVFAULT)
        return 1;
    if (!(status & ATA_STATUS_DRQ))
        return 3;

    // actual byte count the drive will transfer in this block
    size = ((uint32_t)In8(channel->iobase + ATA_REG_LBA2) << 8) | (In8(channel->iobase + ATA_REG_LBA1) & 0xff);
    words = size / 2;
    // drain every word the drive offers, but never store past the caller's buffer
    for (i = 0; i < words; i++)
    {
        uint16_t word = In16(channel->iobase + ATA_REG_DATA);
        if (i < capacity)
            half_buff[i] = word;
    }

    // wait BUSY and DRQ clear, indicating the command finished
    while (In8(channel->iobase + ATA_REG_STATUS) & (ATA_STATUS_BUSY | ATA_STATUS_DRQ))
        CpuIdle();
    return 0;
}

// Fetch one 16-bit word from the channel's data register.
static uint16_t ReadData(ide_channel_t *channel)
{
    uint16_t word = In16(ATA_REG_DATA + channel->iobase);
    return word;
}

// Push one 16-bit word into the channel's data register.
static void WriteData(ide_channel_t *channel, uint16_t data)
{
    Out16(ATA_REG_DATA + channel->iobase, data);
}

/*
 * Read `len` bytes from the channel's 16-bit data register into `buff`.
 * Each ReadData() yields two bytes, stored low byte first (the in-memory layout
 * of a little-endian word). If `len` is odd, the final word is read and only its
 * low byte is kept.
 * Returns `len` on success, -1 if `buff` is NULL.
 */
static int ReadBuff(ide_channel_t *channel, uint8_t *buff, uint32_t len)
{
    uint32_t done = 0;

    if (!buff)
        return -1;
    while (done < len)
    {
        uint16_t word = ReadData(channel);
        buff[done++] = (uint8_t)(word & 0xff);
        if (done < len)
            buff[done++] = (uint8_t)(word >> 8);
    }
    return (int)len;
}

/*
 * Write `len` bytes from `buff` to the channel's 16-bit data register.
 * Bytes are packed in pairs into words, low byte first. If `len` is odd, the
 * last word carries the final byte in its low half and 0 in its high half.
 * Returns `len` on success, -1 if `buff` is NULL.
 */
static int WriteBuff(ide_channel_t *channel, uint8_t *buff, uint32_t len)
{
    uint32_t done = 0;

    if (!buff)
        return -1;
    while (done < len)
    {
        uint16_t word = buff[done++];
        if (done < len)
            word |= (uint16_t)((uint16_t)buff[done++] << 8);
        WriteData(channel, word);
    }
    return (int)len;
}

/*
 * Write `count` SECTOR_SIZE-byte sectors from `buff` through WriteBuff.
 * A `count` of 0 is treated as one sector. Returns WriteBuff's result.
 */
static int WriteToSector(ide_channel_t *channel, uint8_t *buff, uint32_t count)
{
    uint32_t sectors = count;

    if (sectors == 0)
        sectors = 1;
    return WriteBuff(channel, buff, SECTOR_SIZE * sectors);
}

/*
 * Read `count` SECTOR_SIZE-byte sectors into `buff` through ReadBuff.
 * A `count` of 0 is treated as one sector. Returns ReadBuff's result.
 */
static int ReadFromSector(ide_channel_t *channel, uint8_t *buff, uint32_t count)
{
    return ReadBuff(channel, buff, SECTOR_SIZE * (count ? count : 1));
}

// Clear the "irq seen" flag so a later IdeWaitIrq waits for a fresh interrupt.
static void ResetIrq()
{
    irq_invoked = (uint8_t)0;
}

// Yield the CPU until AtapiHandler has set the irq flag (no timeout).
static void IdeWaitIrq()
{
    for (;;)
    {
        if (irq_invoked)
            break;
        TaskYield();
    }
}

/*
 * IDE channel interrupt handler: records that an interrupt arrived.
 * It is registered for every probed channel, so the irq number and the
 * channel pointer passed as `data` are not inspected.
 */
static void AtapiHandler(irqno_t irq, void *data)
{
    (void)irq;
    (void)data;
    irq_invoked = 1;
}

/*
 * Print a human-readable description of an IdePolling-style error code
 * (1 device fault, 2 error bit set, 3 nothing to read, 4 write protect, 5 timeout).
 * For every non-zero code other than 1, the channel and drive are printed too.
 * Returns `err` unchanged so callers can write `if (IdePrintError(ext, err))`.
 */
static int IdePrintError(device_extension_t *extension, uint32_t err)
{
    uint8_t error_reg;

    // no avaliable error
    if (!err)
        return err;
    // device fault
    if (err == 1)
        KPrint("Device Fault\n");
    else
    {
        if (err == 2)
        {
            // The ATA_ER_* bits live in the error register, which is the read side
            // of command-block offset 1 (the write side is the features register).
            // NOTE(review): assumes ATA_REG_FEATURE is that offset-1 register.
            error_reg = In8(extension->channel->iobase + ATA_REG_FEATURE);
            if (error_reg & ATA_ER_AMNF)
                KPrint("No Address Mark Found\n");
            if (error_reg & ATA_ER_TK0NF)
                KPrint("No Media or Media error\n");
            if (error_reg & ATA_ER_ABRT)
                KPrint("Command Abort\n");
            if (error_reg & ATA_ER_MCR)
                KPrint("No Media or Media error\n");
            if (error_reg & ATA_ER_IDNF)
                KPrint("ID Mark Not Found\n");
            if (error_reg & ATA_ER_MC)
                KPrint("No Media or Media error\n");
            if (error_reg & ATA_ER_UNC)
                KPrint("Uncorrectable data error\n");
            if (error_reg & ATA_ER_BBK)
                KPrint("Bad Sector\n");
        }
        else if (err == 3)
            KPrint("Read Nothing\n");
        else if (err == 4)
            KPrint("Write Protect\n");
        else if (err == 5)
            KPrint("Timeout\n");
        // printf channel
        switch (extension->channel - channelinfo)
        {
        case ATA_PRIMARY_CHANNEL:
            KPrint("ATA Primary Channel\n");
            break;
        case ATA_SLAVE_CHANNEL:
            KPrint("ATA Slave Channel\n");
            break;
        default:
            break;
        }
        // printf driver
        switch (extension->driver)
        {
        case ATA_MASTER_DEVICE:
            KPrint("ATA Master Device\n");
            break;
        case ATA_SLAVE_DEVICE:
            KPrint("ATA Slave Device\n");
            break;
        default:
            break;
        }
    }
    return err;
}

/*
 * Wait for the drive on `channel` to finish a command ("polling").
 * Four alternate-status reads give the ~400ns settle delay (each read is
 * assumed to cost ~100ns). The status register is then polled until BSY clears.
 * When `advance` is non-zero the status is read once more and checked:
 *   2 -> error bit set, 1 -> device fault, 3 -> no data request (DRQ clear).
 * Returns 0 when `advance` is zero or all checks pass.
 */
static int IdePolling(ide_channel_t *channel, uint32_t advance)
{
    uint8_t st;
    int delay;

    for (delay = 4; delay > 0; delay--)
        In8(channel->iobase + ATA_REG_ALTSTU);

    do
    {
        st = In8(channel->iobase + ATA_REG_STATUS);
    } while (st & ATA_STATUS_BUSY);

    if (!advance)
        return 0;

    st = In8(channel->iobase + ATA_REG_STATUS);
    if (st & ATA_STATUS_ERR)
        return 2;
    if (st & ATA_STATUS_DEVFAULT)
        return 1;
    return (st & ATA_STATUS_DRQ) ? 0 : 3;
}

/*
 * Software-reset the drives on `channel` by pulsing SRST in the device control register.
 * The device control register is write-only: reading its port returns the
 * alternate status, so the old "read, then restore" sequence wrote status bits
 * back into control bits. The reset is therefore released by writing 0, which
 * also leaves nIEN clear so the channel keeps raising interrupts (the driver
 * depends on them via IdeWaitIrq).
 */
static void DriverSoftReset(ide_channel_t *channel)
{
    int i;

    // assert software reset
    Out8(channel->iobase + ATA_REG_CTRL, ATA_CONTROL_SRST);
    // hold it for a short while; each alternate-status read is a small I/O delay
    for (i = 0; i < 50; i++)
    {
        In8(channel->iobase + ATA_REG_ALTSTU);
    }
    // release reset with interrupts enabled
    Out8(channel->iobase + ATA_REG_CTRL, 0);
}

// Reset the drive described by `extension` by software-resetting its channel.
static void ResetDriver(device_extension_t *extension)
{
    ide_channel_t *chan = extension->channel;

    DriverSoftReset(chan);
}

/*
 * Select the drive of `extension` on its channel via the drive/head register.
 * Any non-zero `mode` is passed to ATA_DEVICE_MASK as 1. The channel then
 * remembers the drive as its currently active one.
 */
static void SelectDisk(device_extension_t *extension, uint8_t mode, uint8_t head)
{
    ide_channel_t *chan = extension->channel;
    uint8_t mode_bit = mode ? 1 : 0;

    Out8(chan->iobase + ATA_REG_HDDSEL, ATA_DEVICE_MASK(mode_bit, extension->driver, head));
    chan->curactive = extension->driver;
}

/*
 * Probe IDE drive number `n`: channel = n / 2, drive on that channel = n % 2.
 * The probe:
 *   - sets up the channel's port bases and irq number;
 *   - registers AtapiHandler once per channel;
 *   - resets and selects the drive, then waits for READY;
 *   - issues IDENTIFY, falling back to IDENTIFY PACKET when the LBA1/LBA2
 *     signature marks an ATAPI drive;
 *   - fills `extension` from the identify data.
 * Returns 0 on success. Returns -1 on failure, in which case the identify buffer
 * is freed and the irq is unregistered only if this call registered it.
 */
static int IdeProbe(device_extension_t *extension, uint32_t n)
{
    // Static storage so the name outlives this call.
    // NOTE(review): unclear whether IrqRegister copies the name; this is safe either way.
    static char irqname[2][32];
    // bit c set => AtapiHandler is currently registered for channel c
    static uint8_t irq_registered = 0;
    uint32_t channel_id = n / 2;
    uint32_t disk_id = n % 2;
    ide_channel_t *channel;
    uint8_t chan_bit;
    int registered_here = 0;
    int timeout;
    uint8_t err;
    uint8_t type;
    uint8_t cl, ch;

    // get target channel object and init its port/irq info
    switch (channel_id)
    {
    case ATA_PRIMARY_CHANNEL:
        channel = &channelinfo[channel_id];
        channel->iobase = ATA_PRIMARY_CMDREG_BA;
        channel->ctrlbase = ATA_PRIMARY_ALTREG_BA;
        channel->irqno = IRQ14_HARDDISK1;
        break;
    case ATA_SLAVE_CHANNEL:
        channel = &channelinfo[channel_id];
        channel->iobase = ATA_SLAVE_CMDREG_BA;
        channel->ctrlbase = ATA_SLAVE_ALTREG_BA;
        channel->irqno = IRQ15_HARDDISK2;
        break;
    default:
        KPrint(PRINT_ERR "[ide]disk %d is not on a known channel\n", n);
        return -1;
    }

    // Register the channel irq once, whichever drive is probed first. Previously only
    // drive 0 registered it, so a slave behind an absent master got no irq, and a
    // failing slave unregistered the irq its working master still needed.
    chan_bit = (uint8_t)(1u << channel_id);
    if (!(irq_registered & chan_bit))
    {
        sprintf(irqname[channel_id], "harddisk channel%d", (int)channel_id);
        IrqRegister(channel->irqno, AtapiHandler, IRQ_DISABLE, "harddisk", irqname[channel_id], channel);
        irq_registered |= chan_bit;
        registered_here = 1;
    }
    // init to 0
    channel->curop = 0;
    channel->curactive = 0;
    // set extension info
    extension->channel = channel;
    extension->driver = disk_id;
    extension->info = KMemAlloc(SECTOR_SIZE);
    if (!extension->info)
    {
        KPrint(PRINT_ERR "KMemAlloc for ide device %s info faild!\n", extension->device_name);
        goto fail_irq;
    }
    // reset driver
    ResetDriver(extension);
    // select disk info
    SelectDisk(extension, 0, 0);

    timeout = 1000; // wait timeout
    // wait disk ready
    while (!(In8(channel->iobase + ATA_REG_STATUS) & ATA_STATUS_READY) && (--timeout))
        ;
    if (timeout <= 0)
    {
        KPrint(PRINT_ERR "[ide]disk %d maybe no ready or not exist\n", n);
        goto fail_free;
    }
    // determine device type
    type = ATA_DEVICE;
    // send IDENTIFY cmd
    SendCmd(channel, ATA_CMD_IDENTIFY);
    err = IdePolling(channel, 1);
    if (err)
    {
        // an ATAPI drive aborts IDENTIFY and leaves its signature in LBA1/LBA2
        cl = In8(channel->iobase + ATA_REG_LBA1);
        ch = In8(channel->iobase + ATA_REG_LBA2);
        if (!((cl == 0x14 && ch == 0xEB) || (cl == 0x69 && ch == 0x96)))
        {
            IdePrintError(extension, err);
            goto fail_free;
        }
        type = ATAPI_DEVICE;
        SendCmd(channel, ATA_CMD_IDENTIFY_PACKET); // send ATAPI identify cmd
        err = IdePolling(channel, 1);
        if (err)
        {
            IdePrintError(extension, err);
            goto fail_free;
        }
    }
    extension->type = type;
    // read device identify data (the data port belongs to the channel)
    ReadFromSector(channel, (uint8_t *)extension->info, 1);
    // widen before shifting so a high bit in cmd_set1 is not shifted into a signed int's sign
    extension->command_sets = ((uint32_t)extension->info->cmd_set1 << 16) | extension->info->cmd_set0;
    if (extension->command_sets & (1 << 26))
    {
        // LBA48 address (low 32 bits of the sector count)
        extension->size = ((uint32_t)extension->info->lba48_sectors[1] << 16) | (uint32_t)extension->info->lba48_sectors[0];
    }
    else
    {
        // CHS or LBA28: high word at index 1, low word at index 0
        extension->size = ((uint32_t)extension->info->lba28_sectors[1] << 16) | (uint32_t)extension->info->lba28_sectors[0];
    }
    extension->capabilities = extension->info->compabilities0;
    extension->signature = extension->info->general_config;
    extension->exist = 1; // device exist
    extension->rwoff = 0;
    // only link the channel to an extension that survived the probe
    channel->extension = extension;
    return 0;

fail_free:
    KMemFree(extension->info);
    extension->info = NULL;
fail_irq:
    if (registered_here)
    {
        IrqUnregister(channel->irqno, channel);
        irq_registered &= (uint8_t)~chan_bit;
    }
    return -1;
}

/*
 * Eject the medium from an ATAPI drive with a PIO PACKET command.
 * Returns:
 *   0  on success;
 *   an IdePolling error code (1..3) if the drive does not accept the packet;
 *   -1 (0xFF, since the return type is uint8_t) if the device is not ATAPI.
 * NOTE(review): the packet layout assumes ATAPI_CMD_EJECT is SCSI START STOP UNIT
 * (0x1B) -- confirm in driver/cdrom.h.
 */
static uint8_t ATAPIEject(device_extension_t *extension)
{
    ide_channel_t *channel = extension->channel;
    uint8_t packet[12] = {0};
    uint8_t err;
    int i;

    if (extension->type != ATAPI_DEVICE) // only ATAPI support eject function
        return -1;

    packet[0] = ATAPI_CMD_EJECT;
    packet[4] = 0x02; // LoEj = 1, Start = 0: unload the medium

    // the target drive must be selected before the command is issued
    SelectDisk(extension, 0, 0);
    Out8(channel->iobase + ATA_REG_FEATURE, 0); // PIO, no DMA
    // send packet command, then wait for BSY clear and DRQ set (drive wants the packet)
    SendCmd(channel, ATA_CMD_PACKET);
    err = IdePolling(channel, 1);
    if (err)
    {
        IdePrintError(extension, err);
        return err;
    }
    // clear the flag right before the packet so we wait for the completion irq
    ResetIrq();
    // 12 packet bytes go out as 6 little-endian words
    for (i = 0; i < 6; i++)
    {
        WriteData(channel, (uint16_t)(packet[2 * i] | (packet[2 * i + 1] << 8)));
    }
    // wait irq and wait busy clear
    IdeWaitIrq();
    IdePolling(channel, 0);
    return 0;
}

// Open request: nothing to set up, so complete it immediately as successful.
static iostatus_t CdromOpen(device_object_t *device, io_request_t *ioreq)
{
    const iostatus_t result = IO_SUCCESS;

    (void)device;
    ioreq->io_status.info = 0;
    ioreq->io_status.status = result;
    IoCompleteRequest(ioreq);
    return result;
}

// Close request: nothing to tear down, so complete it immediately as successful.
static iostatus_t CdromClose(device_object_t *device, io_request_t *ioreq)
{
    const iostatus_t result = IO_SUCCESS;

    (void)device;
    ioreq->io_status.info = 0;
    ioreq->io_status.status = result;
    IoCompleteRequest(ioreq);
    return result;
}

/*
 * Device-control request handler.
 *   DISKIO_GETSIZE: store the device size at *arg.
 *   DISKIO_SETOFF / DISKIO_GETOFF: set or get the stored read offset via *arg.
 *   DISKIO_EJECT: eject an ATAPI medium.
 * Unknown codes, a NULL arg for the pointer codes, and eject failures complete
 * with IO_FAILED.
 */
static iostatus_t CdromDevctl(device_object_t *device, io_request_t *ioreq)
{
    uint32_t code = ioreq->parame.devctl.code;
    uint32_t arg = ioreq->parame.devctl.arg;
    device_extension_t *extension = device->device_extension;
    iostatus_t status = IO_SUCCESS;

    switch (code)
    {
    case DISKIO_GETSIZE:
        if (!arg)
            status = IO_FAILED;
        else
            *(uint32_t *)arg = extension->size;
        break;
    case DISKIO_SETOFF:
        if (!arg)
            status = IO_FAILED;
        else
            extension->rwoff = *(uint32_t *)arg;
        break;
    case DISKIO_GETOFF:
        if (!arg)
            status = IO_FAILED;
        else
            *(uint32_t *)arg = extension->rwoff;
        break;
    case DISKIO_EJECT:
        // The ATA/ATAPI type is stored in the extension by IdeProbe. device->type
        // is the generic class (DEVICE_TYPE_DISK) and never equals ATAPI_DEVICE.
        if (extension->type == ATAPI_DEVICE)
        {
            if (ATAPIEject(extension) != 0)
                status = IO_FAILED;
        }
        else
        {
            KPrint("no ATAPI device no support eject operator!\n");
            status = IO_FAILED;
        }
        break;
    default:
        status = IO_FAILED;
        break;
    }
    ioreq->io_status.status = status;
    ioreq->io_status.info = 0;
    IoCompleteRequest(ioreq);
    return status;
}

/*
 * Read request handler: read whole ATAPI sectors covering read.len bytes into
 * ioreq->sys_buff.
 * An offset of DISKOFF_MAX means "use the offset stored via DISKIO_SETOFF".
 * On success, io_status.info is the number of bytes transferred (whole sectors).
 * On failure, or for a non-ATAPI device, the status is IO_FAILED and info is 0.
 * NOTE(review): whole sectors are written to sys_buff, so a read.len that is not a
 * multiple of ATAPI_SECTOR_SIZE writes past read.len bytes -- confirm sys_buff is
 * sized for that.
 */
static iostatus_t CdromRead(device_object_t *device, io_request_t *ioreq)
{
    device_extension_t *extension = device->device_extension;
    iostatus_t status = IO_FAILED;
    uint32_t sectors = DIV_ROUND_UP(ioreq->parame.read.len, ATAPI_SECTOR_SIZE);
    uint64_t off;
    uint32_t len = 0;

    if (ioreq->parame.read.offset == DISKOFF_MAX)
        off = extension->rwoff;
    else
        off = ioreq->parame.read.offset;

    // ATAPI device data transfer; IdeReadSector returns a signed status, checked as such
    if (extension->type == ATAPI_DEVICE &&
        IdeReadSector(extension, (uint32_t)off, ioreq->sys_buff, sectors) >= 0)
    {
        len = sectors * ATAPI_SECTOR_SIZE;
        status = IO_SUCCESS;
    }
    ioreq->io_status.status = status;
    ioreq->io_status.info = len;
    IoCompleteRequest(ioreq);
    return status;
}

/*
 * Driver entry: probe each IDE drive slot and keep a device object for every ATAPI
 * drive found. Devices whose probe fails, or that turn out not to be ATAPI, are deleted.
 * Returns IO_SUCCESS if at least one CD-ROM was found, IO_FAILED otherwise, or
 * IoCreateDevice's status if creating a device fails.
 */
static iostatus_t CdromEnter(driver_object_t *driver)
{
    iostatus_t status;
    device_object_t *device;
    device_extension_t *extension;
    int i;
    uint8_t count = 0;
    // NOTE(review): reads a byte count of disks at IDE_DISK_NUM; the +1 presumably
    // covers the CD-ROM that count does not include -- confirm.
    uint8_t found = *(uint8_t *)IDE_DISK_NUM + 1;
    KPrint("[ide] system found %d disks\n", found);

    for (i = 0; i < found; i++)
    {
        // NULL so a failed create is never followed by deleting a garbage pointer
        device = NULL;
        status = IoCreateDevice(driver, sizeof(device_extension_t), DEVICE_NAME, DEVICE_TYPE_DISK, &device);
        if (status != IO_SUCCESS)
        {
            KPrint("%s: create device failed!\n", __func__);
            if (device)
                IoDeleteDevice(device);
            return status;
        }

        // neither io mode
        device->flags = DEVICE_BUFFER_IO;
        extension = (device_extension_t *)device->device_extension;
        extension->rwoff = 0;

        if (IdeProbe(extension, i) < 0)
        {
            IoDeleteDevice(device);
            continue;
        }
        if (extension->type != ATAPI_DEVICE)
        {
            // the probe succeeded and allocated the identify buffer; release it with the device
            KMemFree(extension->info);
            extension->info = NULL;
            IoDeleteDevice(device);
            continue;
        }
        KPrint("[driver] found an CD-ROM device on IDE channel %d driver %d\n", (int)(extension->channel - channelinfo), extension->driver);
        count++;
    }

    if (count < 1)
    {
        KPrint("[cdrom] no found cdrom device!\n");
        return IO_FAILED;
    }
    return IO_SUCCESS;
}

/*
 * Driver exit: delete every device object owned by the driver, then free the driver name.
 * Each device's identify buffer (allocated by IdeProbe) is freed first, because
 * nothing else releases it.
 */
static iostatus_t CdromExit(driver_object_t *driver)
{
    device_object_t *device, *next;
    device_extension_t *extension;

    list_traversal_all_owner_to_next_safe(device, next, &driver->device_list, list)
    {
        extension = device->device_extension;
        if (extension && extension->info)
        {
            KMemFree(extension->info);
            extension->info = NULL;
        }
        IoDeleteDevice(device);
    }
    // delete driver name
    string_del(&driver->name);

    return IO_SUCCESS;
}

/*
 * Fill in the driver object: enter/exit hooks, dispatch handlers for
 * open/close/read/devctl, and the driver name.
 * Always returns IO_SUCCESS.
 */
iostatus_t CdromDriverFunc(driver_object_t *driver)
{
    driver->driver_enter = CdromEnter;
    driver->driver_exit = CdromExit;

    driver->dispatch_fun[IOREQ_OPEN] = CdromOpen;
    driver->dispatch_fun[IOREQ_CLOSE] = CdromClose;
    driver->dispatch_fun[IOREQ_READ] = CdromRead;
    driver->dispatch_fun[IOREQ_DEVCTL] = CdromDevctl;

    // driver name
    string_new(&driver->name, DRIVER_NAME, DRIVER_NAME_LEN);

    return IO_SUCCESS;
}

/*
 * Init-call entry point for the CD-ROM driver.
 * Driver registration is currently switched off: the body below is compiled out,
 * so this function does nothing. Remove the #if 0 guard to register the driver.
 */
static __init void CdromDriverEntry()
{
#if 0
    KPrint("[driver] create cdrom driver\n");
    if (DriverObjectCreate(CdromDriverFunc) < 0)
    {
        KPrint("[driver] %s create driver failed!\n", __func__);
    }
#endif
}

driver_initcall(CdromDriverEntry);